/*
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.google.firebase.ai

import android.graphics.Bitmap
import com.google.firebase.FirebaseApp
import com.google.firebase.ai.common.APIController
import com.google.firebase.ai.common.AppCheckHeaderProvider
import com.google.firebase.ai.common.CountTokensRequest
import com.google.firebase.ai.common.GenerateContentRequest
import com.google.firebase.ai.type.Content
import com.google.firebase.ai.type.CountTokensResponse
import com.google.firebase.ai.type.FinishReason
import com.google.firebase.ai.type.FirebaseAIException
import com.google.firebase.ai.type.GenerateContentResponse
import com.google.firebase.ai.type.GenerationConfig
import com.google.firebase.ai.type.GenerativeBackend
import com.google.firebase.ai.type.GenerativeBackendEnum
import com.google.firebase.ai.type.InvalidStateException
import com.google.firebase.ai.type.PromptBlockedException
import com.google.firebase.ai.type.RequestOptions
import com.google.firebase.ai.type.ResponseStoppedException
import com.google.firebase.ai.type.SafetySetting
import com.google.firebase.ai.type.SerializationException
import com.google.firebase.ai.type.Tool
import com.google.firebase.ai.type.ToolConfig
import com.google.firebase.ai.type.content
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.map
import kotlinx.serialization.ExperimentalSerializationApi

/**
 * Represents a multimodal model (like Gemini), capable of generating content based on various input
 * types.
 *
 * Every request built by this model carries the model name, generation config, safety settings,
 * tools, tool config and system instruction supplied at construction time. Errors raised while
 * building a request, calling the backend, or validating a response are passed through
 * [FirebaseAIException.from] before reaching the caller.
 */
public class GenerativeModel
internal constructor(
  private val modelName: String,
  private val generationConfig: GenerationConfig? = null,
  private val safetySettings: List<SafetySetting>? = null,
  private val tools: List<Tool>? = null,
  private val toolConfig: ToolConfig? = null,
  private val systemInstruction: Content? = null,
  private val generativeBackend: GenerativeBackend = GenerativeBackend.googleAI(),
  private val controller: APIController,
) {
  /**
   * Convenience constructor that builds the [APIController] for this model from the given API key,
   * [FirebaseApp] and request options, attaching an [AppCheckHeaderProvider] configured with the
   * supplied App Check / Auth providers.
   */
  internal constructor(
    modelName: String,
    apiKey: String,
    firebaseApp: FirebaseApp,
    useLimitedUseAppCheckTokens: Boolean,
    generationConfig: GenerationConfig? = null,
    safetySettings: List<SafetySetting>? = null,
    tools: List<Tool>? = null,
    toolConfig: ToolConfig? = null,
    systemInstruction: Content? = null,
    requestOptions: RequestOptions = RequestOptions(),
    generativeBackend: GenerativeBackend,
    appCheckTokenProvider: InteropAppCheckTokenProvider? = null,
    internalAuthProvider: InternalAuthProvider? = null
  ) : this(
    modelName,
    generationConfig,
    safetySettings,
    tools,
    toolConfig,
    systemInstruction,
    generativeBackend,
    APIController(
      apiKey,
      modelName,
      requestOptions,
      "gl-kotlin/${KotlinVersion.CURRENT}-ai fire/${BuildConfig.VERSION_NAME}",
      firebaseApp,
      AppCheckHeaderProvider(
        TAG,
        useLimitedUseAppCheckTokens,
        appCheckTokenProvider,
        internalAuthProvider
      ),
    ),
  )

  /**
   * Generates new content from the input [Content] given to the model as a prompt.
   *
   * @param prompt The first input given to the model as a prompt.
   * @param prompts Additional inputs, sent after [prompt] in the order given.
   * @return The content generated by the model.
   * @throws [FirebaseAIException] if the request failed, or if the response was empty, blocked, or
   *   stopped for a reason other than [FinishReason.STOP].
   * @see [FirebaseAIException] for types of errors.
   */
  public suspend fun generateContent(
    prompt: Content,
    vararg prompts: Content
  ): GenerateContentResponse = rethrowAsFirebaseAIException {
    controller.generateContent(constructRequest(prompt, *prompts)).toPublic().validate()
  }

  /**
   * Generates new content from the input [Content] given to the model as a prompt.
   *
   * @param prompt The input(s) given to the model as a prompt, sent in list order.
   * @return The content generated by the model.
   * @throws [FirebaseAIException] if the request failed, or if the response was empty, blocked, or
   *   stopped for a reason other than [FinishReason.STOP].
   * @see [FirebaseAIException] for types of errors.
   */
  public suspend fun generateContent(prompt: List<Content>): GenerateContentResponse =
    rethrowAsFirebaseAIException {
      controller.generateContent(constructRequest(prompt)).toPublic().validate()
    }

  /**
   * Generates new content as a stream from the input [Content] given to the model as a prompt.
   *
   * The request is built immediately, so request-construction errors are thrown from this call.
   * Errors from the network call and from per-chunk validation are thrown when the returned [Flow]
   * is collected.
   *
   * @param prompt The first input given to the model as a prompt.
   * @param prompts Additional inputs, sent after [prompt] in the order given.
   * @return A [Flow] which will emit responses as they are returned by the model.
   * @throws [FirebaseAIException] if the request failed.
   * @see [FirebaseAIException] for types of errors.
   */
  public fun generateContentStream(
    prompt: Content,
    vararg prompts: Content
  ): Flow<GenerateContentResponse> =
    streamContent(rethrowAsFirebaseAIException { constructRequest(prompt, *prompts) })

  /**
   * Generates new content as a stream from the input [Content] given to the model as a prompt.
   *
   * The request is built immediately, so request-construction errors are thrown from this call.
   * Errors from the network call and from per-chunk validation are thrown when the returned [Flow]
   * is collected.
   *
   * @param prompt The input(s) given to the model as a prompt, sent in list order.
   * @return A [Flow] which will emit responses as they are returned by the model.
   * @throws [FirebaseAIException] if the request failed.
   * @see [FirebaseAIException] for types of errors.
   */
  public fun generateContentStream(prompt: List<Content>): Flow<GenerateContentResponse> =
    streamContent(rethrowAsFirebaseAIException { constructRequest(prompt) })

  /**
   * Generates new content from the text input given to the model as a prompt.
   *
   * @param prompt The text to be sent to the model as a prompt.
   * @return The content generated by the model.
   * @throws [FirebaseAIException] if the request failed.
   * @see [FirebaseAIException] for types of errors.
   */
  public suspend fun generateContent(prompt: String): GenerateContentResponse =
    generateContent(content { text(prompt) })

  /**
   * Generates new content as a stream from the text input given to the model as a prompt.
   *
   * @param prompt The text to be sent to the model as a prompt.
   * @return A [Flow] which will emit responses as they are returned by the model.
   * @throws [FirebaseAIException] if the request failed.
   * @see [FirebaseAIException] for types of errors.
   */
  public fun generateContentStream(prompt: String): Flow<GenerateContentResponse> =
    generateContentStream(content { text(prompt) })

  /**
   * Generates new content from the image input given to the model as a prompt.
   *
   * @param prompt The image to be converted into a single piece of [Content] to send to the model.
   * @return The content generated by the model.
   * @throws [FirebaseAIException] if the request failed.
   * @see [FirebaseAIException] for types of errors.
   */
  public suspend fun generateContent(prompt: Bitmap): GenerateContentResponse =
    generateContent(content { image(prompt) })

  /**
   * Generates new content as a stream from the image input given to the model as a prompt.
   *
   * @param prompt The image to be converted into a single piece of [Content] to send to the model.
   * @return A [Flow] which will emit responses as they are returned by the model.
   * @throws [FirebaseAIException] if the request failed.
   * @see [FirebaseAIException] for types of errors.
   */
  public fun generateContentStream(prompt: Bitmap): Flow<GenerateContentResponse> =
    generateContentStream(content { image(prompt) })

  /**
   * Creates a [Chat] instance using this model with the optionally provided history.
   *
   * [history] is copied into a new mutable list, so later changes to the caller's list do not
   * affect the chat.
   */
  public fun startChat(history: List<Content> = emptyList()): Chat =
    Chat(this, history.toMutableList())

  /**
   * Counts the number of tokens in a prompt using the model's tokenizer.
   *
   * @param prompt The first input given to the model as a prompt.
   * @param prompts Additional inputs, counted after [prompt] in the order given.
   * @return The [CountTokensResponse] of running the model's tokenizer on the input.
   * @throws [FirebaseAIException] if the request failed.
   * @see [FirebaseAIException] for types of errors.
   */
  public suspend fun countTokens(prompt: Content, vararg prompts: Content): CountTokensResponse =
    rethrowAsFirebaseAIException {
      controller.countTokens(constructCountTokensRequest(prompt, *prompts)).toPublic()
    }

  /**
   * Counts the number of tokens in a prompt using the model's tokenizer.
   *
   * @param prompt The input(s) given to the model as a prompt.
   * @return The [CountTokensResponse] of running the model's tokenizer on the input.
   * @throws [FirebaseAIException] if the request failed.
   * @see [FirebaseAIException] for types of errors.
   */
  public suspend fun countTokens(prompt: List<Content>): CountTokensResponse =
    rethrowAsFirebaseAIException {
      controller.countTokens(constructCountTokensRequest(*prompt.toTypedArray())).toPublic()
    }

  /**
   * Counts the number of tokens in a text prompt using the model's tokenizer.
   *
   * @param prompt The text given to the model as a prompt.
   * @return The [CountTokensResponse] of running the model's tokenizer on the input.
   * @throws [FirebaseAIException] if the request failed.
   * @see [FirebaseAIException] for types of errors.
   */
  public suspend fun countTokens(prompt: String): CountTokensResponse =
    countTokens(content { text(prompt) })

  /**
   * Counts the number of tokens in an image prompt using the model's tokenizer.
   *
   * @param prompt The image given to the model as a prompt.
   * @return The [CountTokensResponse] of running the model's tokenizer on the input.
   * @throws [FirebaseAIException] if the request failed.
   * @see [FirebaseAIException] for types of errors.
   */
  public suspend fun countTokens(prompt: Bitmap): CountTokensResponse =
    countTokens(content { image(prompt) })

  /**
   * Runs [block] and rethrows anything it throws after passing it through
   * [FirebaseAIException.from], so all public entry points surface the same exception hierarchy.
   *
   * Being `inline`, [block] may call suspend functions when used from a suspend function.
   *
   * NOTE(review): catching [Throwable] also intercepts coroutine `CancellationException`s; confirm
   * that [FirebaseAIException.from] preserves cancellation semantics.
   */
  private inline fun <T> rethrowAsFirebaseAIException(block: () -> T): T =
    try {
      block()
    } catch (e: Throwable) {
      throw FirebaseAIException.from(e)
    }

  /**
   * Streams [request] through the controller and converts every chunk to its public, validated
   * form.
   *
   * `catch` is placed after `map` so that errors raised while converting or validating a chunk are
   * passed through [FirebaseAIException.from] just like transport errors. This matches the
   * non-streaming [generateContent] overloads, which validate inside their error wrapper. `catch`
   * only sees upstream exceptions, so errors thrown by the collector itself are not rewrapped.
   */
  private fun streamContent(request: GenerateContentRequest): Flow<GenerateContentResponse> =
    controller
      .generateContentStream(request)
      .map { it.toPublic().validate() }
      .catch { throw FirebaseAIException.from(it) }

  /**
   * Builds the internal [GenerateContentRequest] for [prompt], attaching this model's
   * configuration. The system instruction, if any, is sent with its role forced to `"system"`.
   *
   * @throws InvalidStateException if any safety setting specifies a `method` while this model
   *   targets the Google AI backend.
   */
  @OptIn(ExperimentalSerializationApi::class)
  private fun constructRequest(vararg prompt: Content) =
    GenerateContentRequest(
      modelName,
      prompt.map { it.toInternal() },
      safetySettings
        ?.also { safetySettingList ->
          if (
            generativeBackend.backend == GenerativeBackendEnum.GOOGLE_AI &&
              safetySettingList.any { it.method != null }
          ) {
            throw InvalidStateException(
              "HarmBlockMethod is unsupported by the Google Developer API"
            )
          }
        }
        ?.map { it.toInternal() },
      generationConfig?.toInternal(),
      tools?.map { it.toInternal() },
      toolConfig?.toInternal(),
      systemInstruction?.copy(role = "system")?.toInternal(),
    )

  /** List overload of [constructRequest]; contents are sent in list order. */
  private fun constructRequest(prompt: List<Content>) = constructRequest(*prompt.toTypedArray())

  /** Wraps a generate-content request in the count-tokens shape expected by the active backend. */
  private fun constructCountTokensRequest(vararg prompt: Content) =
    when (generativeBackend.backend) {
      GenerativeBackendEnum.GOOGLE_AI -> CountTokensRequest.forGoogleAI(constructRequest(*prompt))
      GenerativeBackendEnum.VERTEX_AI -> CountTokensRequest.forVertexAI(constructRequest(*prompt))
    }

  /**
   * Returns this response unchanged if it is usable.
   *
   * @throws SerializationException if the response has neither candidates nor prompt feedback.
   * @throws PromptBlockedException if the prompt feedback carries a block reason.
   * @throws ResponseStoppedException if any candidate finished for a reason other than
   *   [FinishReason.STOP] (candidates without a finish reason are ignored).
   */
  private fun GenerateContentResponse.validate() = apply {
    if (candidates.isEmpty() && promptFeedback == null) {
      throw SerializationException("Error deserializing response, found no valid fields")
    }
    if (promptFeedback?.blockReason != null) {
      throw PromptBlockedException(this)
    }
    val stoppedEarly =
      candidates.any { candidate ->
        val reason = candidate.finishReason
        reason != null && reason != FinishReason.STOP
      }
    if (stoppedEarly) {
      throw ResponseStoppedException(this)
    }
  }

  private companion object {
    private val TAG = GenerativeModel::class.java.simpleName
  }
}
